# 输入 3*3*3，输出5*2*2

import torch
import torch.nn as nn

# 1. 构造输入
x = torch.arange(1, 28, dtype=torch.float32).reshape(1, 3, 3, 3)
# x.shape = (1,3,3,3)

# 2. 定义卷积：输入通道3 → 输出通道5，核 2×2，步长1，填充0
conv = nn.Conv2d(in_channels=3, out_channels=5,
                 kernel_size=2, stride=1, padding=0, bias=True)

# 3. 手动把权重/偏置设成“看得见”的值
#    权重形状 (5,3,2,2) 共 5×3×2×2 = 60 个数
#    偏置形状 (5,)
with torch.no_grad():
    conv.weight.fill_(0.1)  # 所有权重=0.1
    conv.bias.copy_(torch.arange(1, 6, dtype=torch.float32))
    # bias = [1,2,3,4,5]

# 4. 前向计算
y = conv(x)  # 输出形状 (1,5,2,2)

print("输入:")
print(x)
# tensor([[[[ 1.,  2.,  3.],
#           [ 4.,  5.,  6.],
#           [ 7.,  8.,  9.]],
#
#          [[10., 11., 12.],
#           [13., 14., 15.],
#           [16., 17., 18.]],
#
#          [[19., 20., 21.],
#           [22., 23., 24.],
#           [25., 26., 27.]]]])

print("\n输出:")
print(y)
"""
手算验证（以输出第 0 通道左上角值 8.8 为例）
该位置对应的输入 patch 为：
[ 1, 2, 4, 5 ]      通道 0
[10,11,13,14]      通道 1
[19,20,22,23]      通道 2

每个元素乘 0.1，再求和，再加偏置 1：
(1+2+4+5+10+11+13+14+19+20+22+23) * 0.1 + 1 = 14.4 + 1 = 15.4
"""
# tensor([[[[15.4000, 16.6000],
#           [19.0000, 20.2000]],
#
#          [[16.4000, 17.6000],
#           [20.0000, 21.2000]],
#
#          [[17.4000, 18.6000],
#           [21.0000, 22.2000]],
#
#          [[18.4000, 19.6000],
#           [22.0000, 23.2000]],
#
#          [[19.4000, 20.6000],
#           [23.0000, 24.2000]]]], grad_fn=<ConvolutionBackward0>)
